# 32.2 企业安全配置

18 分钟阅读

## 32.2.1 身份验证与授权

### API 密钥管理

#### 集中式密钥管理

class APIKeyManager:
    """API key manager that reads secrets from Vault and caches them in memory."""

    def __init__(self):
        # Vault connection settings come from the VAULT_ADDR / VAULT_TOKEN
        # environment variables.
        self.vault_url = os.getenv('VAULT_ADDR')
        self.vault_token = os.getenv('VAULT_TOKEN')
        # key_name -> {'key': <secret>, 'timestamp': <time.time() at fetch>}
        self.key_cache = {}
        self.cache_ttl = 3600  # 1 hour
        # Seconds to wait for Vault; without a timeout a hung Vault server
        # would block the caller indefinitely.
        self.request_timeout = 10

    def get_key(self, key_name: str) -> str:
        """Return the API key *key_name*, served from the cache while fresh.

        A cached entry is reused for ``cache_ttl`` seconds after it was
        fetched; otherwise the key is fetched from Vault and re-cached.

        Raises:
            Exception: if the key cannot be fetched from Vault.
        """
        # Check the cache first.
        cached = self.key_cache.get(key_name)
        if cached is not None and time.time() - cached['timestamp'] < self.cache_ttl:
            return cached['key']

        # Cache miss or stale entry: fetch from Vault.
        key = self._fetch_from_vault(key_name)

        # Cache the key together with its fetch time.
        self.key_cache[key_name] = {'key': key, 'timestamp': time.time()}
        return key

    def _fetch_from_vault(self, key_name: str) -> str:
        """Read the ``value`` field of the Vault secret at ``secret/data/<key_name>``.

        Raises:
            Exception: when VAULT_ADDR is unset, on a non-200 response, on a
                network error or on an unexpected response body. The error is
                logged and then re-raised.
        """
        try:
            if not self.vault_url:
                # Otherwise the URL would silently become "None/v1/...".
                raise Exception("VAULT_ADDR is not configured")
            response = requests.get(
                f"{self.vault_url}/v1/secret/data/{key_name}",
                headers={'X-Vault-Token': self.vault_token},
                timeout=self.request_timeout,
            )
            if response.status_code != 200:
                raise Exception(f"Failed to fetch key: {response.status_code}")
            data = response.json()
            # The secret payload is nested under data.data (presumably the
            # KV v2 secrets engine, given the /secret/data/ path).
            return data['data']['data']['value']
        except Exception as e:
            logger.error(f"Error fetching key from vault: {e}")
            raise

#### 密钥轮换策略

bash
python

class KeyRotationManager:
    """Tracks per-key rotation schedules and performs key rotations."""

    def __init__(self):
        # key_name -> {'interval_days': int, 'next_rotation': datetime,
        #              'last_rotation': datetime or None}
        self.rotation_schedule = {}
        # One record per rotation; stores key hashes, never raw keys.
        self.rotation_history = []

    def schedule_rotation(self, key_name: str,
                          interval_days: int = 90):
        """Schedule *key_name* to be rotated every *interval_days* days.

        The first rotation becomes due *interval_days* from now. Scheduling
        an already-scheduled key replaces its schedule.

        Raises:
            ValueError: if *interval_days* is not positive (a zero or negative
                interval would leave the key permanently due for rotation).
        """
        if interval_days <= 0:
            raise ValueError(f"interval_days must be positive, got {interval_days}")

        next_rotation = datetime.now() + timedelta(days=interval_days)
        self.rotation_schedule[key_name] = {
            'interval_days': interval_days,
            'next_rotation': next_rotation,
            'last_rotation': None
        }

        logger.info(f"Scheduled rotation for {key_name} in {interval_days} days")

    def check_rotations(self) -> List[str]:
        """Return the names of scheduled keys whose rotation is due now."""
        now = datetime.now()
        return [
            key_name
            for key_name, schedule in self.rotation_schedule.items()
            if schedule['next_rotation'] <= now
        ]

    def rotate_key(self, key_name: str) -> "RotationResult":
        """Rotate *key_name* and return a RotationResult describing the outcome.

        Errors are not raised. They are reported via ``result.success = False``
        and ``result.error``. If the key has a schedule, its ``last_rotation``
        and ``next_rotation`` are advanced. Keys without a schedule can still
        be rotated (manual rotation).
        """
        result = RotationResult(key_name=key_name)

        try:
            # Read the current key BEFORE the configuration is updated; reading
            # it afterwards could return the new key and record a wrong hash.
            old_key = self._get_old_key(key_name)

            # Generate the new key.
            new_key = self._generate_new_key()

            # Update the configuration.
            self._update_key_configuration(key_name, new_key)

            # Use one timestamp so the history and the schedule agree exactly.
            rotated_at = datetime.now()

            # Record the rotation.
            self.rotation_history.append({
                'key_name': key_name,
                'rotated_at': rotated_at,
                'old_key_hash': self._hash_key(old_key),
                'new_key_hash': self._hash_key(new_key)
            })

            # Advance the schedule, if there is one. Previously an unscheduled
            # key raised KeyError here, after the key had already been
            # rotated, and the rotation was reported as failed.
            schedule = self.rotation_schedule.get(key_name)
            if schedule is not None:
                schedule['last_rotation'] = rotated_at
                schedule['next_rotation'] = rotated_at + timedelta(
                    days=schedule['interval_days']
                )

            result.success = True
            result.new_key = new_key

        except Exception as e:
            result.success = False
            result.error = str(e)

        return result

### SSO 集成

#### OAuth 2.0 配置

class SSOAuthenticator:
    """OAuth 2.0 authorization-code flow client for SSO login."""

    def __init__(self, config: Dict):
        """Load client settings from *config*.

        Required keys: 'client_id', 'client_secret', 'redirect_uri',
        'auth_url', 'token_url'; a missing key raises KeyError.
        Optional keys: 'scopes' (default ['openid', 'profile']) and
        'timeout' (seconds for token-endpoint requests, default 10).
        """
        self.client_id = config['client_id']
        self.client_secret = config['client_secret']
        self.redirect_uri = config['redirect_uri']
        self.auth_url = config['auth_url']
        self.token_url = config['token_url']
        self.scopes = config.get('scopes', ['openid', 'profile'])
        self.request_timeout = config.get('timeout', 10)
        # The state sent by the most recent get_auth_url() call. The callback
        # handler must compare the 'state' echoed back by the provider against
        # it to reject forged (CSRF) callbacks. NOTE: an instance shared by
        # many users should persist the state per session instead.
        self.last_state = None

    def get_auth_url(self, state: str = None) -> str:
        """Build the provider's authorization URL.

        Args:
            state: Anti-CSRF value. A random one is generated when it is
                omitted or empty. The value actually used is stored in
                ``self.last_state``.

        Returns:
            ``auth_url`` with the OAuth query parameters appended. They are
            joined with '&' when ``auth_url`` already has a query string.
        """
        state = state or self._generate_state()
        self.last_state = state
        params = {
            'response_type': 'code',
            'client_id': self.client_id,
            'redirect_uri': self.redirect_uri,
            'scope': ' '.join(self.scopes),
            'state': state
        }
        separator = '&' if '?' in self.auth_url else '?'
        return f"{self.auth_url}{separator}{urllib.parse.urlencode(params)}"

    def exchange_code_for_token(self,
                                auth_code: str) -> "TokenResponse":
        """Exchange an authorization code for an access token.

        Raises:
            Exception: if the token endpoint does not answer with HTTP 200.
        """
        data = {
            'grant_type': 'authorization_code',
            'code': auth_code,
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'redirect_uri': self.redirect_uri
        }
        return self._post_token_request(data, 'Token exchange')

    def refresh_access_token(self,
                             refresh_token: str) -> "TokenResponse":
        """Obtain a new access token using *refresh_token*.

        If the provider does not return a new refresh token, the one passed
        in is kept.

        Raises:
            Exception: if the token endpoint does not answer with HTTP 200.
        """
        data = {
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
            'client_id': self.client_id,
            'client_secret': self.client_secret
        }
        return self._post_token_request(data, 'Token refresh',
                                        fallback_refresh_token=refresh_token)

    def _post_token_request(self, data: Dict, action: str,
                            fallback_refresh_token: str = None) -> "TokenResponse":
        """POST *data* to the token endpoint and build a TokenResponse.

        *action* prefixes the error message ("<action> failed: <status>").
        *fallback_refresh_token* is used when the response has no
        'refresh_token'. ``expires_in`` defaults to 3600 and ``token_type``
        to 'Bearer'.
        """
        response = requests.post(self.token_url, data=data,
                                 timeout=self.request_timeout)
        if response.status_code != 200:
            raise Exception(f"{action} failed: {response.status_code}")
        token_data = response.json()
        return TokenResponse(
            access_token=token_data['access_token'],
            refresh_token=token_data.get('refresh_token', fallback_refresh_token),
            expires_in=token_data.get('expires_in', 3600),
            token_type=token_data.get('token_type', 'Bearer')
        )

    def _generate_state(self) -> str:
        """Generate a random, URL-safe state value."""
        return secrets.token_urlsafe(16)

### 多因素认证 (MFA)

bash
python

class MFAAuthenticator:
    """Verifies multi-factor authentication (MFA) codes for a user."""

    def __init__(self):
        # Dispatch table: MFA method name -> verifier(code, user_id) -> bool.
        self.mfa_methods = {
            'totp': self._verify_totp,
            'sms': self._verify_sms,
            'email': self._verify_email
        }

    def verify_mfa(self, method: str,
                   code: str,
                   user_id: str) -> bool:
        """Verify *code* for *user_id* with the given MFA *method*.

        Raises:
            ValueError: if *method* is not 'totp', 'sms' or 'email'.
        """
        verifier = self.mfa_methods.get(method)

        if not verifier:
            raise ValueError(f"Unsupported MFA method: {method}")

        return verifier(code, user_id)

    @staticmethod
    def _codes_match(stored_code, supplied_code) -> bool:
        """Compare a stored one-time code with the user-supplied one.

        Returns False when either side is missing or empty. Otherwise a user
        with no pending code (stored None / '') could pass by submitting the
        same empty value. Uses hmac.compare_digest so the comparison time does
        not reveal how many leading characters matched.
        """
        if not stored_code or not supplied_code:
            return False
        return hmac.compare_digest(str(stored_code).encode('utf-8'),
                                   str(supplied_code).encode('utf-8'))

    def _verify_totp(self, code: str, user_id: str) -> bool:
        """Verify a TOTP code against the user's shared secret.

        ``valid_window=1`` also accepts the codes of the adjacent time steps,
        which tolerates small clock drift.
        """
        # Get the user's TOTP secret.
        secret = self._get_totp_secret(user_id)
        if not secret:
            # No TOTP secret for this user: nothing to verify against.
            return False

        return pyotp.TOTP(secret).verify(code, valid_window=1)

    def _verify_sms(self, code: str, user_id: str) -> bool:
        """Verify an SMS code against the code stored for *user_id*."""
        stored_code = self._get_stored_sms_code(user_id)

        # NOTE(review): _is_code_expired takes only user_id, so SMS and e-mail
        # codes share one expiry check, and codes are not consumed here.
        # Confirm that the caller invalidates the code after success, to
        # prevent replay.
        return self._codes_match(stored_code, code) and not self._is_code_expired(user_id)

    def _verify_email(self, code: str, user_id: str) -> bool:
        """Verify an e-mail code against the code stored for *user_id*."""
        stored_code = self._get_stored_email_code(user_id)

        return self._codes_match(stored_code, code) and not self._is_code_expired(user_id)

## 32.2.2 权限控制

### 基于角色的访问控制 (RBAC)

class RBACManager:
    """In-memory role-based access control (RBAC) registry."""

    def __init__(self):
        # role_name -> list of permission strings
        self.roles = {}
        # Not read by any method of this class; kept for compatibility.
        self.permissions = {}
        # user_id -> list of role names (no duplicates, assignment order)
        self.user_roles = {}

    def define_role(self, role_name: str,
                    permissions: List[str]):
        """Define (or redefine) *role_name* with the given *permissions*.

        The permissions are copied, so later changes to the caller's list do
        not silently change the role.
        """
        self.roles[role_name] = list(permissions)
        logger.info(f"Role {role_name} defined with {len(self.roles[role_name])} permissions")

    def assign_role(self, user_id: str, role_name: str):
        """Assign the already-defined role *role_name* to *user_id*.

        Assigning a role the user already has is a no-op.

        Raises:
            ValueError: if *role_name* has not been defined.
        """
        if role_name not in self.roles:
            raise ValueError(f"Role {role_name} not defined")

        assigned = self.user_roles.setdefault(user_id, [])
        if role_name not in assigned:
            assigned.append(role_name)
        logger.info(f"Role {role_name} assigned to user {user_id}")

    def check_permission(self, user_id: str,
                         permission: str) -> bool:
        """Return True if any of the user's roles grants *permission*."""
        return any(
            permission in self.roles.get(role, ())
            for role in self.user_roles.get(user_id, ())
        )

    def get_user_permissions(self, user_id: str) -> List[str]:
        """Return every permission granted to *user_id*, sorted and de-duplicated.

        Sorting makes the result deterministic; a set's iteration order is not.
        """
        all_permissions = set()
        for role in self.user_roles.get(user_id, ()):
            all_permissions.update(self.roles.get(role, ()))
        return sorted(all_permissions)

### 权限策略定义

bash
python

class PermissionPolicy:
    """Static catalogue of permissions and the built-in roles.

    ``PERMISSIONS`` maps each '<resource>:<action>' identifier to a
    human-readable description. ``ROLES`` maps each role name to the list of
    permission identifiers it grants, ordered from least to most privileged
    role.
    """

    # Permission identifier -> description.
    PERMISSIONS = dict([
        ('code:generate', 'Generate code'),
        ('code:read', 'Read code'),
        ('code:write', 'Write code'),
        ('code:delete', 'Delete code'),
        ('file:read', 'Read files'),
        ('file:write', 'Write files'),
        ('file:delete', 'Delete files'),
        ('tool:execute', 'Execute tools'),
        ('config:manage', 'Manage configuration'),
        ('user:manage', 'Manage users'),
    ])

    # Role name -> granted permissions. Each role is a superset of the one
    # before it.
    ROLES = {
        'viewer': ['code:read', 'file:read'],
        'developer': ['code:read', 'code:write', 'code:generate',
                      'file:read', 'file:write',
                      'tool:execute'],
        'senior_developer': ['code:read', 'code:write', 'code:generate', 'code:delete',
                             'file:read', 'file:write', 'file:delete',
                             'tool:execute'],
        'admin': ['code:read', 'code:write', 'code:generate', 'code:delete',
                  'file:read', 'file:write', 'file:delete',
                  'tool:execute',
                  'config:manage', 'user:manage'],
    }

### 权限检查中间件

class PermissionMiddleware:
    """Permission-checking middleware built on an RBACManager."""

    def __init__(self, rbac_manager: "RBACManager"):
        # Quoted annotation: it is not evaluated when the class is created.
        self.rbac_manager = rbac_manager

    def check_permission(self,
                         user_id: str,
                         required_permission: str) -> bool:
        """Return whether *user_id* holds *required_permission*; log denials."""
        has_permission = self.rbac_manager.check_permission(
            user_id,
            required_permission
        )

        if not has_permission:
            logger.warning(
                f"Permission denied: user={user_id}, "
                f"permission={required_permission}"
            )

        return has_permission

    def require_permission(self, permission: str):
        """Decorator factory: only run the wrapped function if the current user holds *permission*.

        Raises:
            PermissionError: at call time, when the permission is missing.
        """
        import functools

        def decorator(func):
            # functools.wraps keeps the wrapped function's __name__, __doc__
            # and other metadata, which introspection and logging rely on.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                user_id = self._get_user_id()
                if not self.check_permission(user_id, permission):
                    raise PermissionError(
                        f"Permission denied: {permission}"
                    )
                return func(*args, **kwargs)
            return wrapper
        return decorator

    def _get_user_id(self) -> str:
        """Return the current user ID, or 'anonymous' if none is set."""
        # NOTE(review): the USER_ID environment variable is process-wide and
        # set by whoever launches the process. In a multi-user server the ID
        # should come from the authenticated request/session instead.
        return os.getenv('USER_ID', 'anonymous')

## 32.2.3 审计日志

### 审计日志记录器

bash
python

class AuditLogger:
    """Writes audit events as JSON payloads to a dedicated log file."""

    def __init__(self, config: Dict):
        """Configure the shared 'audit' logger from *config*.

        Keys: 'log_file' (default 'audit.log'), 'log_level' (level name or
        int, default 'INFO') and 'retention_days' (default 90).
        """
        self.log_file = config.get('log_file', 'audit.log')
        self.log_level = config.get('log_level', 'INFO')
        self.retention_days = config.get('retention_days', 90)

        # logging.getLogger returns one process-wide logger per name, so every
        # AuditLogger instance shares this logger and its handlers.
        self.logger = logging.getLogger('audit')
        if isinstance(self.log_level, int):
            self.logger.setLevel(self.log_level)
        else:
            self.logger.setLevel(getattr(logging, str(self.log_level).upper()))

        # Attach the file handler only once per file. Otherwise each new
        # instance adds another handler and every event is written repeatedly.
        target = os.path.abspath(self.log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in self.logger.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(self.log_file)
            # Each line is written as '<asctime> - <levelname> - <json>';
            # _parse_timestamp relies on this shape.
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(levelname)s - %(message)s'
                )
            )
            self.logger.addHandler(handler)

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive datetime (the stored format)."""
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def log_event(self, event: "AuditEvent"):
        """Write *event* as one JSON object, timestamped in UTC."""
        log_entry = {
            'timestamp': self._utc_now().isoformat(),
            'user_id': event.user_id,
            'action': event.action,
            'resource': event.resource,
            'result': event.result,
            'ip_address': event.ip_address,
            'user_agent': event.user_agent,
            'metadata': event.metadata
        }

        # default=str: an audit record must not be lost just because the
        # metadata holds a non-JSON value (e.g. a datetime).
        self.logger.info(json.dumps(log_entry, default=str))

    def log_api_call(self, user_id: str,
                     endpoint: str,
                     method: str,
                     status_code: int,
                     duration_ms: float):
        """Record an API call: endpoint as resource, HTTP status as result."""
        event = AuditEvent(
            user_id=user_id,
            action='API_CALL',
            resource=endpoint,
            result=str(status_code),
            metadata={
                'method': method,
                'duration_ms': duration_ms
            }
        )
        self.log_event(event)

    def log_file_access(self, user_id: str,
                        file_path: str,
                        action: str,
                        result: str):
        """Record a file access as action 'FILE_<ACTION>'."""
        event = AuditEvent(
            user_id=user_id,
            action=f'FILE_{action.upper()}',
            resource=file_path,
            result=result
        )
        self.log_event(event)

    def log_permission_check(self, user_id: str,
                             permission: str,
                             granted: bool):
        """Record a permission check with result 'GRANTED' or 'DENIED'."""
        event = AuditEvent(
            user_id=user_id,
            action='PERMISSION_CHECK',
            resource=permission,
            result='GRANTED' if granted else 'DENIED'
        )
        self.log_event(event)

    @staticmethod
    def _parse_timestamp(line: str):
        """Return the entry timestamp of one log line as a naive UTC datetime, or None.

        The JSON payload starts at the first '{'; the asctime/level prefix is
        skipped. Returns None for lines that cannot be parsed.
        """
        from datetime import timezone

        start = line.find('{')
        if start < 0:
            return None
        try:
            entry = json.loads(line[start:])
            log_date = datetime.fromisoformat(entry['timestamp'])
        except (ValueError, KeyError, TypeError):
            return None
        if log_date.tzinfo is not None:
            # Normalize aware timestamps so they compare with the naive cutoff.
            log_date = log_date.astimezone(timezone.utc).replace(tzinfo=None)
        return log_date

    @classmethod
    def _filter_expired_lines(cls, lines, cutoff_date):
        """Return *lines* without the entries timestamped at or before *cutoff_date*.

        Unparseable lines are kept so that nothing unexpected is deleted.
        """
        kept = []
        for line in lines:
            log_date = cls._parse_timestamp(line)
            if log_date is None or log_date > cutoff_date:
                kept.append(line)
        return kept

    def cleanup_old_logs(self):
        """Remove entries older than ``retention_days`` from the log file.

        The cutoff is computed in UTC, matching the timestamps written by
        log_event. A missing log file means nothing to clean up.
        """
        cutoff_date = self._utc_now() - timedelta(days=self.retention_days)

        try:
            with open(self.log_file, 'r') as f:
                lines = f.readlines()
        except FileNotFoundError:
            return

        filtered_lines = self._filter_expired_lines(lines, cutoff_date)

        # NOTE(review): the FileHandler keeps this file open in append mode.
        # On POSIX, later writes go to the new end of the file; on Windows,
        # rewriting a file held open elsewhere may fail.
        with open(self.log_file, 'w') as f:
            f.writelines(filtered_lines)

        logger.info(f"Cleaned up audit logs, removed {len(lines) - len(filtered_lines)} entries")

### 审计事件类型

class AuditEvent:
    """A single audit record: who did what to which resource, and with what result."""

    def __init__(self,
                 user_id: str,
                 action: str,
                 resource: str = None,
                 result: str = None,
                 ip_address: str = None,
                 user_agent: str = None,
                 metadata: Dict = None):
        self.user_id = user_id
        self.action = action
        self.resource = resource
        self.result = result
        # Missing (or empty) client details fall back to the environment.
        self.ip_address = ip_address or self._get_client_ip()
        self.user_agent = user_agent or self._get_user_agent()
        # Copy so that later changes to the caller's dict cannot alter an
        # event that has already been created.
        self.metadata = dict(metadata) if metadata else {}

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(user_id={self.user_id!r}, "
                f"action={self.action!r}, resource={self.resource!r}, "
                f"result={self.result!r})")

    def _get_client_ip(self) -> str:
        """Return the client IP from REMOTE_ADDR, or 'unknown'."""
        # NOTE(review): reads the CGI-style REMOTE_ADDR environment variable;
        # in a web framework the IP should come from the request object.
        return os.getenv('REMOTE_ADDR', 'unknown')

    def _get_user_agent(self) -> str:
        """Return the user agent from HTTP_USER_AGENT, or 'unknown'."""
        return os.getenv('HTTP_USER_AGENT', 'unknown')

## 32.2.4 数据保护

### 数据分类

bash
python

class DataClassifier:
    """Assigns a sensitivity level to a piece of text data."""

    # Regexes whose presence marks data as PII; compiled once, not per call.
    _PII_PATTERNS = (
        re.compile(r'\b\d{3}-\d{2}-\d{4}\b'),  # SSN, e.g. 123-45-6789
        # 16-digit card number, optionally grouped in fours by space or dash.
        re.compile(r'\b(?:\d{4}[ -]?){3}\d{4}\b'),
        # Email; the TLD class is [A-Za-z] ('|' is not a letter).
        re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'),
    )

    def __init__(self):
        # Level name -> description and examples. This is reference metadata
        # only; classify() does not read it.
        self.classification_rules = {
            'public': {
                'description': '可以公开访问的数据',
                'examples': ['public documentation', 'open source code']
            },
            'internal': {
                'description': '仅限内部访问的数据',
                'examples': ['internal documentation', 'proprietary code']
            },
            'confidential': {
                'description': '需要特殊保护的数据',
                'examples': ['customer data', 'financial information']
            },
            'restricted': {
                'description': '最高级别的保护',
                'examples': ['PII', 'trade secrets']
            }
        }

    def classify(self, data: str,
                 context: Dict = None) -> str:
        """Return 'restricted', 'confidential', 'internal' or 'public' for *data*.

        PII found in *data* takes precedence. Otherwise *context* decides:
        'source' == 'customer' gives 'confidential', and 'access_level' ==
        'internal' gives 'internal'. Everything else is 'public'.
        """
        if self._contains_pii(data):
            return 'restricted'

        if context:
            if context.get('source') == 'customer':
                return 'confidential'
            if context.get('access_level') == 'internal':
                return 'internal'

        return 'public'

    def _contains_pii(self, data: str) -> bool:
        """Return True if any PII pattern occurs anywhere in *data*."""
        return any(pattern.search(data) for pattern in self._PII_PATTERNS)

### 数据脱敏

class DataMasker:
    """Masks sensitive values: e-mail, phone, SSN, credit card and IP address."""

    def __init__(self):
        # Data type -> masking function. These are the types mask_data handles.
        self.masking_rules = {
            'email': self._mask_email,
            'phone': self._mask_phone,
            'ssn': self._mask_ssn,
            'credit_card': self._mask_credit_card,
            'ip_address': self._mask_ip_address
        }

    def mask_data(self, data: str,
                  data_type: str = 'auto') -> str:
        """Mask *data* as *data_type*, or detect the type when it is 'auto'.

        Values of an unknown or unsupported type are returned unchanged.
        """
        if data_type == 'auto':
            data_type = self._detect_data_type(data)

        masker = self.masking_rules.get(data_type)
        return masker(data) if masker else data

    def _mask_email(self, email: str) -> str:
        """Keep the first/last char of a local part longer than 3 chars; else '***'."""
        if '@' not in email:
            return email

        local, domain = email.split('@', 1)
        masked_local = local[0] + '***' + local[-1:] if len(local) > 3 else '***'
        return f"{masked_local}@{domain}"

    def _mask_phone(self, phone: str) -> str:
        """Keep only the last 4 digits of a number with at least 10 digits."""
        digits = re.sub(r'\D', '', phone)
        if len(digits) >= 10:
            return f"***-***-{digits[-4:]}"
        return '***-***'

    def _mask_ssn(self, ssn: str) -> str:
        """Keep only the last 4 digits of a 9-digit SSN."""
        digits = re.sub(r'\D', '', ssn)
        if len(digits) == 9:
            return f"***-**-{digits[-4:]}"
        return '***-**-****'

    def _mask_credit_card(self, card: str) -> str:
        """Keep only the last 4 digits of a number with at least 13 digits."""
        digits = re.sub(r'\D', '', card)
        if len(digits) >= 13:
            return f"****-****-****-{digits[-4:]}"
        return '****-****-****-****'

    def _mask_ip_address(self, ip: str) -> str:
        """Mask an IP address, keeping only its leading network part.

        IPv4 keeps the first two octets ('192.168.1.20' -> '192.168.*.*').
        IPv6 keeps the first two groups of the fully expanded form.
        Unparseable input is fully masked as '*.*.*.*'.
        """
        # Previously this method was referenced in __init__ but never defined,
        # so constructing DataMasker raised AttributeError.
        import ipaddress

        try:
            addr = ipaddress.ip_address(ip.strip())
        except ValueError:
            return '*.*.*.*'

        if addr.version == 4:
            first, second = str(addr).split('.')[:2]
            return f"{first}.{second}.*.*"
        groups = addr.exploded.split(':')
        return ':'.join(groups[:2] + ['*'] * 6)

    def _detect_data_type(self, data: str) -> str:
        """Guess the type of *data*: email, ssn, ip_address, credit_card, phone or 'unknown'."""
        import ipaddress

        if '@' in data and '.' in data.split('@')[1]:
            return 'email'
        if re.match(r'^\d{3}-\d{2}-\d{4}$', data):
            return 'ssn'
        try:
            ipaddress.ip_address(data.strip())
            return 'ip_address'
        except ValueError:
            pass

        digits = re.sub(r'\D', '', data)
        if re.match(r'^\d{16}$', digits):
            return 'credit_card'
        if re.match(r'^\d{10}$', digits):
            return 'phone'
        return 'unknown'

标记本节教程为已读

记录您的学习进度,方便后续查看。